{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28e6e614-e360-4292-965e-0d255027e9b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2021 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b88dc1a-a92d-44cc-9fb7-d9e2ef20c8e2",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# Accelerating HuggingFace T5 Inference with TensorRT\n",
    "\n",
    "T5 is an encoder-decoder model that converts all NLP problems into a text-to-text format. More specifically, it does so by encoding different tasks as text directives in the input stream. This enables a single model to be trained in a supervised fashion on a wide variety of NLP tasks such as translation, classification, Q&A, and summarization.\n",
    "\n",
    "This notebook shows 3 easy steps to convert a [HuggingFace PyTorch T5 model](https://huggingface.co/transformers/model_doc/t5.html) to a TensorRT engine for high-performance inference.\n",
    "\n",
    "1. [Download HuggingFace T5 model](#1)\n",
    "1. [Convert to ONNX format](#2)\n",
    "1. [Convert to TensorRT engine](#3)\n",
    "\n",
    "## Prerequisite\n",
    "\n",
    "Follow the instructions at https://github.com/NVIDIA/TensorRT to build the TensorRT-OSS Docker container required to run this notebook.\n",
    "\n",
    "Next, we install some extra dependencies, then restart the kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0c36ecb7-c622-4d95-a851-b9a6eb18e81b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip3 install -r ../requirements.txt\n",
    "\n",
    "import IPython\n",
    "app = IPython.Application.instance()\n",
    "app.kernel.do_shutdown(True)  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "235d2f1b-439e-4cd0-8286-1d63a13f2cf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "ROOT_DIR = os.path.abspath(\"../\")\n",
    "sys.path.append(ROOT_DIR)\n",
    "\n",
    "import torch\n",
    "import tensorrt as trt\n",
    "\n",
    "# huggingface\n",
    "from transformers import (\n",
    "    T5ForConditionalGeneration,\n",
    "    T5Tokenizer,\n",
    "    T5Config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af4254e2-11fd-4bc7-ac0b-60b1a9e07c4e",
   "metadata": {},
   "source": [
    "<a id=\"1\"></a>\n",
    "\n",
    "## 1. Download HuggingFace T5 model\n",
    "\n",
    "First, we download the original HuggingFace PyTorch T5 model from the HuggingFace model hub, together with its associated tokenizer.\n",
    "\n",
    "The T5 variants supported by TensorRT 8 are t5-small (60M), t5-base (220M), and t5-large (770M)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fae66d58-f994-4987-8f1d-1fa8ac2ec8b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "T5_VARIANT = 't5-small' # choices: t5-small | t5-base | t5-large\n",
    "\n",
    "t5_model = T5ForConditionalGeneration.from_pretrained(T5_VARIANT)\n",
    "tokenizer = T5Tokenizer.from_pretrained(T5_VARIANT)\n",
    "# Load the variant's configuration; T5Config(T5_VARIANT) would pass the name as vocab_size\n",
    "config = T5Config.from_pretrained(T5_VARIANT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7252ca90-1104-40dc-8e72-f51c07a4cd11",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pytorch Model saved to ./models/t5-small/pytorch\n"
     ]
    }
   ],
   "source": [
    "# save model locally\n",
    "pytorch_model_dir = './models/{}/pytorch'.format(T5_VARIANT)\n",
    "!mkdir -p $pytorch_model_dir\n",
    "\n",
    "t5_model.save_pretrained(pytorch_model_dir)\n",
    "print(\"Pytorch Model saved to {}\".format(pytorch_model_dir))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11ea023d-c4d4-43bb-9d77-c76684e0b06f",
   "metadata": {},
   "source": [
    "### Inference with PyTorch model\n",
    "\n",
    "Next, we will carry out inference with the PyTorch model.\n",
    "\n",
    "#### Single example inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bc45d9bc-b6ef-485e-8832-6628c292e315",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer(\"translate English to German: That is good.\", return_tensors=\"pt\")\n",
    "\n",
    "# inference on a single example\n",
    "t5_model.eval()\n",
    "with torch.no_grad():\n",
    "    outputs = t5_model(**inputs, labels=inputs[\"input_ids\"])\n",
    "\n",
    "logits = outputs.logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "98f7fd8b-2ee3-4d25-9204-7713eb7e90b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Das ist gut.\n"
     ]
    }
   ],
   "source": [
    "# Generate sequence for an input\n",
    "outputs = t5_model.to('cuda:0').generate(inputs.input_ids.to('cuda:0'))\n",
    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
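  {
   "cell_type": "markdown",
   "id": "added-0001-task-prefix-note",
   "metadata": {},
   "source": [
    "The task is selected purely by the text prefix of the input. As a minimal sketch, the cell below sends the same sentence through the model with two different prefixes. The `prompts` list and the loop variable names are illustrative and not part of the original workflow; the cell assumes `t5_model` and `tokenizer` from the cells above are still loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-0001-task-prefix-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: the same model handles different tasks depending on the prefix\n",
    "prompts = [\n",
    "    \"translate English to German: That is good.\",\n",
    "    \"translate English to French: That is good.\",\n",
    "]\n",
    "\n",
    "for prompt in prompts:\n",
    "    # Tokenize and place the ids on whichever device the model currently lives on\n",
    "    prompt_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(t5_model.device)\n",
    "    generated = t5_model.generate(prompt_ids)\n",
    "    print(prompt, \"->\", tokenizer.decode(generated[0], skip_special_tokens=True))"
   ]
  },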
  {
   "cell_type": "markdown",
   "id": "667fcacc-02cb-415d-a9ff-2d2ec44ef225",
   "metadata": {},
   "source": [
    "#### Model inference benchmark: encoder and decoder stacks\n",
    "\n",
    "For benchmarking purposes, we will employ the helper functions `encoder_inference` and `decoder_inference`, which execute inference repeatedly for the T5 encoder and decoder stacks separately and measure the end-to-end execution time. Let's take note of this execution time for comparison with TensorRT.\n",
    " \n",
    "`TimingProfile` is a named tuple that specifies the number of experiments and the number of times to call the function per iteration (as well as the number of warm-up calls, although it is not used here)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "596ea542-d9e5-4367-b643-d60027fa05e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from T5.measurements import decoder_inference, encoder_inference, full_inference_greedy\n",
    "from T5.export import T5EncoderTorchFile, T5DecoderTorchFile\n",
    "from NNDF.networks import TimingProfile\n",
    "\n",
    "t5_torch_encoder = T5EncoderTorchFile.TorchModule(t5_model.encoder)\n",
    "t5_torch_decoder = T5DecoderTorchFile.TorchModule(\n",
    "    t5_model.decoder, t5_model.lm_head, t5_model.config\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "be755fbc-c53e-4f8d-a9c2-4817167cf93a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0072555395308882"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids = inputs.input_ids\n",
    "\n",
    "encoder_last_hidden_state, encoder_e2e_median_time = encoder_inference(\n",
    "    t5_torch_encoder, input_ids, TimingProfile(iterations=10, number=1, warmup=1)\n",
    ")\n",
    "encoder_e2e_median_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "960f05fc-f572-4832-ad82-8a75823866b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.011791097989771515"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, decoder_e2e_median_time = decoder_inference(\n",
    "    t5_torch_decoder, input_ids, encoder_last_hidden_state, TimingProfile(iterations=10, number=1, warmup=1)\n",
    ")\n",
    "decoder_e2e_median_time"
   ]
  },
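  {
   "cell_type": "markdown",
   "id": "added-0002-timing-sketch-note",
   "metadata": {},
   "source": [
    "The benchmark helpers above report a median over repeated calls. A minimal sketch of that measurement pattern, using only the standard library, might look like the cell below. `median_runtime` is a hypothetical helper written for illustration and is not the implementation in `T5/measurements.py`; its parameters mirror the `TimingProfile` fields as they are used in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-0002-timing-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "import timeit\n",
    "\n",
    "\n",
    "def median_runtime(fn, iterations=10, number=1, warmup=1):\n",
    "    # Hypothetical helper: median wall-clock seconds per call of fn\n",
    "    for _ in range(warmup):\n",
    "        fn()  # warm-up calls are executed but not timed\n",
    "    # Each entry is the total time for `number` consecutive calls\n",
    "    totals = timeit.repeat(fn, repeat=iterations, number=number)\n",
    "    return statistics.median(total / number for total in totals)\n",
    "\n",
    "\n",
    "# Toy CPU workload; a GPU workload would also need torch.cuda.synchronize()\n",
    "# inside the timed callable for the measurement to be meaningful\n",
    "median_runtime(lambda: sum(range(100000)))"
   ]
  },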
  {
   "cell_type": "markdown",
   "id": "a99d5a06-a8f5-4ce7-a34c-bc42f07ac706",
   "metadata": {},
   "source": [
    "#### Full model inference and benchmark\n",
    "\n",
    "Next, we will try the T5 model for the task of translation from English to German.\n",
    "\n",
    "For benchmarking purposes, we will employ the helper function `full_inference_greedy`, which executes inference repeatedly and measures the end-to-end execution time. Let's take note of this execution time for comparison with TensorRT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "39d511cf-d963-4629-be54-22e9a258716d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0667644675122574"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from T5.T5ModelConfig import T5ModelTRTConfig\n",
    "\n",
    "decoder_output_greedy, full_e2e_median_runtime = full_inference_greedy(\n",
    "    t5_torch_encoder,\n",
    "    t5_torch_decoder,\n",
    "    input_ids,\n",
    "    tokenizer,\n",
    "    TimingProfile(iterations=10, number=1, warmup=1),\n",
    "    max_length=T5ModelTRTConfig.MAX_SEQUENCE_LENGTH[T5_VARIANT],\n",
    ")\n",
    "full_e2e_median_runtime"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cff48fc-b792-4852-b638-6e2c54099cb2",
   "metadata": {},
   "source": [
    "Let us decode the model's output back into text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "839bc6bc-65dc-499d-ac26-81456dbc1748",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Das ist gut.\n"
     ]
    }
   ],
   "source": [
    "# De-tokenize output to raw text\n",
    "print(tokenizer.decode(decoder_output_greedy[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d662701-e430-4fdc-ad46-1f296defcf8f",
   "metadata": {},
   "source": [
    "<a id=\"2\"></a>\n",
    "\n",
    "## 2. Convert to ONNX\n",
    "\n",
    "Prior to converting the model to a TensorRT engine, we will first convert the PyTorch model to an intermediate universal format.\n",
    "\n",
    "ONNX is an open format for machine learning and deep learning models. It allows you to convert deep learning and machine learning models from different frameworks such as TensorFlow, PyTorch, MATLAB, Caffe, and Keras to a single format.\n",
    "\n",
    "The steps to convert a PyTorch model to TensorRT are as follows:\n",
    "- Convert the pretrained PyTorch model into ONNX.\n",
    "- Import the ONNX model into TensorRT.\n",
    "- Apply optimizations and generate an engine.\n",
    "- Perform inference on the GPU.\n",
    "\n",
    "For the T5 model, we will convert the encoder and decoder separately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c2b2be1a-021c-4f6c-957d-2ff7d1b95976",
   "metadata": {},
   "outputs": [],
   "source": [
    "# helpers\n",
    "from NNDF.networks import NetworkMetadata, Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c50346f7-6c2c-4e4b-ba70-875688947b75",
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model_path = './models/{}/ONNX'.format(T5_VARIANT)\n",
    "!mkdir -p $onnx_model_path\n",
    "\n",
    "metadata=NetworkMetadata(T5_VARIANT, Precision('fp16'), None)\n",
    "\n",
    "encoder_onnx_model_fpath = T5_VARIANT + \"-encoder.onnx\"\n",
    "decoder_onnx_model_fpath = T5_VARIANT + \"-decoder-with-lm-head.onnx\"\n",
    "\n",
    "t5_encoder = T5EncoderTorchFile(t5_model.to('cpu'), metadata)\n",
    "t5_decoder = T5DecoderTorchFile(t5_model.to('cpu'), metadata)\n",
    "\n",
    "onnx_t5_encoder = t5_encoder.as_onnx_model(\n",
    "    os.path.join(onnx_model_path, encoder_onnx_model_fpath), force_overwrite=False\n",
    ")\n",
    "onnx_t5_decoder = t5_decoder.as_onnx_model(\n",
    "    os.path.join(onnx_model_path, decoder_onnx_model_fpath), force_overwrite=False\n",
    ")"
   ]
  },
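  {
   "cell_type": "markdown",
   "id": "added-0003-onnx-check-note",
   "metadata": {},
   "source": [
    "Before building engines, it can be worth verifying that the exported files are structurally valid ONNX. The sketch below assumes the `onnx` Python package is installed in the container (it is not imported elsewhere in this notebook) and reuses the path variables defined in the previous cell. It is an optional sanity check, not part of the original conversion workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-0003-onnx-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx  # assumption: available in the environment\n",
    "\n",
    "for fname in (encoder_onnx_model_fpath, decoder_onnx_model_fpath):\n",
    "    path = os.path.join(onnx_model_path, fname)\n",
    "    # Passing the file path lets the checker read the model from disk,\n",
    "    # which also works for models larger than 2 GB\n",
    "    onnx.checker.check_model(path)\n",
    "    print(\"ONNX check passed: {}\".format(path))"
   ]
  },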
  {
   "cell_type": "markdown",
   "id": "7baf007e-5508-485c-a87f-9bfe16260452",
   "metadata": {},
   "source": [
    "<a id=\"3\"></a>\n",
    "\n",
    "## 3. Convert to TensorRT\n",
    "\n",
    "Now we are ready to parse the ONNX encoder and decoder models and convert them to optimized TensorRT engines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "037ac958-2627-439c-9db5-27640e3f7967",
   "metadata": {},
   "outputs": [],
   "source": [
    "from T5.export import T5DecoderONNXFile, T5EncoderONNXFile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6bd6e3fc-6797-46b0-a211-ce42d3769105",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorrt_model_path = './models/{}/tensorrt'.format(T5_VARIANT)\n",
    "!mkdir -p $tensorrt_model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "cfb64120-9012-40c8-b1e2-4a6366b71294",
   "metadata": {},
   "outputs": [],
   "source": [
    "t5_trt_encoder_engine = T5EncoderONNXFile(\n",
    "                os.path.join(onnx_model_path, encoder_onnx_model_fpath), metadata\n",
    "            ).as_trt_engine(os.path.join(tensorrt_model_path, encoder_onnx_model_fpath) + \".engine\")\n",
    "\n",
    "t5_trt_decoder_engine = T5DecoderONNXFile(\n",
    "                os.path.join(onnx_model_path, decoder_onnx_model_fpath), metadata\n",
    "            ).as_trt_engine(os.path.join(tensorrt_model_path, decoder_onnx_model_fpath) + \".engine\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74f7f6fc-1e6a-4ddc-8e9b-543d9e8dab4d",
   "metadata": {},
   "source": [
    "### Inference with TensorRT engine\n",
    "\n",
    "Great, if you have reached this stage, it means we now have an optimized TensorRT engine for the T5 model, ready for us to carry out inference. \n",
    "\n",
    "#### Single example inference\n",
    "The T5 model with TensorRT backend can now be employed in place of the original HuggingFace T5 model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3954f2f4-c393-463b-a44b-3e5335032b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize TensorRT engines\n",
    "from T5.trt import T5TRTEncoder, T5TRTDecoder\n",
    "\n",
    "tfm_config = T5Config(\n",
    "    use_cache=True,\n",
    "    num_layers=T5ModelTRTConfig.NUMBER_OF_LAYERS[T5_VARIANT],\n",
    ")\n",
    "    \n",
    "t5_trt_encoder = T5TRTEncoder(\n",
    "                t5_trt_encoder_engine, metadata, tfm_config\n",
    "            )\n",
    "t5_trt_decoder = T5TRTDecoder(\n",
    "                t5_trt_decoder_engine, metadata, tfm_config\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a9544ecb-2671-4b53-a544-08f13424cefe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference on a single sample\n",
    "encoder_last_hidden_state = t5_trt_encoder(input_ids=input_ids)\n",
    "outputs = t5_trt_decoder(input_ids, encoder_last_hidden_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8d71a327-546f-4b5b-bd42-caaffcceafc7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Das ist gut.\n"
     ]
    }
   ],
   "source": [
    "# Generate sequence for an input\n",
    "from transformers.generation_stopping_criteria import (\n",
    "    MaxLengthCriteria,\n",
    "    StoppingCriteriaList,\n",
    ")\n",
    "\n",
    "max_length = 64\n",
    "\n",
    "decoder_input_ids = torch.full(\n",
    "    (1, 1), tokenizer.convert_tokens_to_ids(tokenizer.pad_token), dtype=torch.int32\n",
    ")\n",
    "encoder_last_hidden_state = t5_trt_encoder(input_ids=input_ids)\n",
    "\n",
    "outputs = t5_trt_decoder.greedy_search(\n",
    "            input_ids=decoder_input_ids,\n",
    "            encoder_hidden_states=encoder_last_hidden_state,\n",
    "            stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length)])\n",
    "        )\n",
    "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed9d4a98-b034-470e-a9f8-096d4100b8d4",
   "metadata": {},
   "source": [
    "#### TRT engine inference benchmark: encoder and decoder stacks\n",
    "First, we will benchmark the encoder and decoder stacks as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "70b37591-4398-40ff-8a39-5f75347192dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.000644649553578347"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoder_last_hidden_state, encoder_e2e_median_time = encoder_inference(\n",
    "    t5_trt_encoder, input_ids, TimingProfile(10,1,1)\n",
    ")\n",
    "encoder_e2e_median_time\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "7e5459da-a01b-4894-88dc-01b3637ded53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0014052424812689424"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, decoder_e2e_median_time = decoder_inference(\n",
    "    t5_trt_decoder, input_ids, encoder_last_hidden_state, TimingProfile(10,1,1)\n",
    ")\n",
    "decoder_e2e_median_time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62ebfe03-7a60-4dd0-ad32-4e53d6012b07",
   "metadata": {},
   "source": [
    "#### Full model inference benchmark\n",
    "\n",
    "Next, we will try the full TensorRT T5 engine for the task of translation. As before, note the time difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f31cb550-24b9-48cd-a4ec-0bf18ac5e40c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Das ist gut.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.007905778475105762"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoder_output_greedy, full_e2e_median_runtime = full_inference_greedy(\n",
    "    t5_trt_encoder,\n",
    "    t5_trt_decoder,\n",
    "    input_ids,\n",
    "    tokenizer,\n",
    "    TimingProfile(10,1,1),\n",
    "    max_length=T5ModelTRTConfig.MAX_SEQUENCE_LENGTH[metadata.variant],\n",
    "    use_cuda=False,\n",
    ")\n",
    "\n",
    "print(tokenizer.decode(decoder_output_greedy[0], skip_special_tokens=True))\n",
    "full_e2e_median_runtime\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92031643-8ee8-4d50-864b-a08e4d551dc6",
   "metadata": {},
   "source": [
    "You can now compare the output of the original PyTorch model and the TensorRT engine. Notice the speed difference. On an NVIDIA V100 32GB GPU, this results in up to a ~10x performance improvement (from 0.0802s to 0.0082s for the T5-small variant)."
   ]
  },
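  {
   "cell_type": "markdown",
   "id": "added-0004-speedup-note",
   "metadata": {},
   "source": [
    "As a small worked example, the speedup for the particular run recorded in this notebook can be computed from the two full-model median runtimes shown in the cell outputs above. The values are hard-coded here because `full_e2e_median_runtime` is overwritten by the TensorRT run; the variable names are illustrative, and you should substitute your own measurements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-0004-speedup-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Median runtimes in seconds, copied from the recorded cell outputs above\n",
    "pytorch_runtime = 0.0667644675122574\n",
    "trt_runtime = 0.007905778475105762\n",
    "\n",
    "speedup = pytorch_runtime / trt_runtime\n",
    "print(\"TensorRT speedup over PyTorch: {:.1f}x\".format(speedup))"
   ]
  },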
  {
   "cell_type": "markdown",
   "id": "2a1f5dca-397c-4c8c-9200-61b30cdba824",
   "metadata": {},
   "source": [
    "## Conclusion and where to next?\n",
    "\n",
    "This notebook has walked you through the process of converting a HuggingFace PyTorch T5 model to an optimized TensorRT engine for inference in 3 easy steps. The TensorRT inference engine can be conveniently used as a drop-in replacement for the original HuggingFace T5 model while providing a significant speedup.\n",
    "\n",
    "If you are interested in further details of the conversion process, check out [T5/trt.py](../T5/trt.py)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
